#!/usr/bin/env python3

# Python script to run and analyse MMS test
#
# Tests a range of different schemes
#

#requires: all_tests

from __future__ import division
from __future__ import print_function
from builtins import zip
from builtins import str

from boututils.run_wrapper import shell, shell_safe, launch_safe
from boututils.datafile import DataFile
from boutdata.collect import collect

from numpy import sqrt, max, abs, mean, array, log, pi

from os.path import join, isfile

import matplotlib.pyplot as plt

import pickle

from sys import stdout

## Get the tokamak shape
## Note: This must be the same as in the mms.py code
from boutdata.mms import SimpleTokamak, x, y
from sympy import sin, cos
# Build the tokamak geometry object and fetch its metric
shape = SimpleTokamak()
m = shape.metric()
B0 = m.B

# Constants and normalisation factors
MU0 = 4.0e-7*pi
Lbar = 1.
Bbar = 1.

## Equilibrium profiles, written as symbolic expressions in x and y
jpar_profile = 1 - x - sin(x*pi)**2 * cos(y)   # Parallel current
pressure_profile = 2 + cos(x*pi)   # Pressure pedestal
bxcvz = -(1./shape.Rxy)**2*cos(y)  # Curvature

# Apply the normalisation
J0 = - jpar_profile * shape.Bxy / (MU0 * Lbar)  # Turn into A/m^2
P0 = pressure_profile * Bbar**2 / (2.0*MU0)  # Pascals

# Register the profiles with the shape under the names used for output;
# the registration order (pressure, Jpar0, bxcvz) is kept fixed
for name, profile in (("pressure", P0), ("Jpar0", J0), ("bxcvz", bxcvz)):
    shape.add(profile, name)



# Compile the test executable; the build output is redirected to make.log
build_command = "make > make.log"
print("Making MMS elm-pb test")
shell_safe(build_command)

# Resolutions (NX values) to scan; the larger sizes remain commented out
nxlist = [8, 16, 32, 64]  # , 128]#, 256]
# Number of MPI processes for each run, paired element-wise with nxlist
nprocs = [4, 4, 8, 8, 8, 8]

# Overall verdict; cleared as soon as one variable fails the order check
success = True

# Variables whose error is checked, with the plot marker and legend label
# used for each of them
varlist = ["P", "Psi", "U"]
markers = ['bo', 'r^', "gs"]
labels = [r'$P$', r'$\psi$', r'$U$']

# Per-variable error histories, one entry appended per resolution
error_2 = {name: [] for name in varlist}    # The L2 error (RMS)
error_inf = {name: [] for name in varlist}  # The maximum error

running = True  # Run the simulations?

# For every resolution: optionally (re)run the simulation, then collect the
# error fields and record their norms. zip() stops at the shorter list, so
# surplus entries in nprocs are simply ignored.
for nx, nproc in zip(nxlist, nprocs):
    # Each resolution keeps its input and output in its own directory
    directory = "grid%d" % nx

    if running:
        # Generate a mesh file if needed
        filename = "grid%d.nc" % nx

        if isfile(filename):
            print("Grid file '%s' already exists" % filename)
        else:
            print("Creating grid file '%s'" % filename)
            f = DataFile(filename, create=True)
            try:
                shape.write(nx, nx, f)
            finally:
                # Close the grid file even if writing it fails
                f.close()

        args = " MZ="+str(nx)+" grid="+filename #+" solver:timestep="+str(0.1/nx)

        print("  Running with " + args)

        # Run each simulation in a different directory.
        # mkdir and rm go through the non-checking shell() because they are
        # expected to fail when the directory already exists or when there
        # is no old data; the copy of the input file must succeed.
        shell("mkdir " + directory)
        shell_safe("cp data/BOUT.inp "+directory)
        # Delete old data
        shell("rm %s/BOUT.dmp.*.nc" % directory)

        args = " -d "+directory+" "+args

        # Command to run
        cmd = "./elm_pb "+args
        # Launch using MPI; the exit status is not needed here
        _, out = launch_safe(cmd, nproc=nproc, pipe=True)

        # Save output to log file
        with open("run.log."+str(nx), "w") as logfile:
            logfile.write(out)

    for var in varlist:
        # Collect the error field "E_<var>" at time index 1
        E = collect("E_"+var, tind=[1,1], info=False, path=directory)
        # Drop two points at each end of the second axis
        # (presumably the X guard/boundary cells -- TODO confirm)
        E = E[:,2:-2, :,:]

        # Error norms over the remaining points. Note that sqrt, mean, max
        # and abs are the NumPy functions imported at the top of the file.
        l2 = sqrt(mean(E**2))
        linf = max(abs( E ))

        error_2[var].append( l2 )
        error_inf[var].append( linf )

        print("%s : l-2 %f l-inf %f" % (var, l2, linf))
    
    
# Mesh spacing for each resolution, used as the x-axis of the error plot
dx = 1. / array(nxlist)

# Save data: three consecutive pickles in one file, in the order
# resolutions, L2 errors, maximum errors
with open("elm_pb.pkl", "wb") as output:
    for record in (nxlist, error_2, error_inf):
        pickle.dump(record, output)

# Plot both error norms against mesh spacing and check the convergence order
plt.figure()

for var, mark, label in zip(varlist, markers, labels):
    l2_errors = error_2[var]

    # Solid line: L2 error (carries the legend entry); dashed: maximum error
    plt.plot(dx, l2_errors, '-'+mark, label=label)
    plt.plot(dx, error_inf[var], '--'+mark)

    # Convergence order estimated from the last two entries of the L2 error
    error_ratio = l2_errors[-1] / l2_errors[-2]
    spacing_ratio = dx[-1] / dx[-2]
    order = log(error_ratio) / log(spacing_ratio)
    stdout.write("%s Convergence order = %f" % (var, order))

    # Should be second order accurate
    converged = 1.8 < order < 2.2
    if not converged:
        success = False
    print("............ PASS" if converged else "............ FAIL")


plt.legend(loc="upper left")
plt.grid()

# Both axes are logarithmic
plt.yscale('log')
plt.xscale('log')

plt.xlabel(r'Mesh spacing $\delta x$')
plt.ylabel("Error norm")

# Write the figure to disk before showing it on screen
plt.savefig("norm.pdf")

plt.show()
plt.close()

# Report the overall verdict and pass it to the caller through the process
# exit status: 0 when every variable converged at the expected order,
# 1 otherwise.
if success:
    print(" => All tests passed")
else:
    print(" => Some failed tests")

# Raise SystemExit directly instead of calling the bare exit() helper: that
# helper is injected by the site module for interactive use and is not
# guaranteed to exist (e.g. under "python -S").
raise SystemExit(0 if success else 1)
